import os
# Restrict CUDA to the device with index 1 for this process. This is set before
# `torch` is imported below so the restriction is in place when CUDA initializes.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftConfig, PeftModel


# Run on the (single visible) GPU when available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# LoRA adapter checkpoint directory; both the adapter config and the adapter
# weights are read from it below.
model_path = '/data/home/zhanghx/code/DataContaminate/ckpts/model/newllama-7b/seed_1/answer_0.87'

# The adapter config records which base checkpoint the adapter was trained on;
# the base model and its tokenizer are both loaded from that checkpoint.
config = PeftConfig.from_pretrained(model_path)
base_checkpoint = config.base_model_name_or_path

tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)

# Load the base causal LM and attach the LoRA adapter on top of it.
lora_model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(base_checkpoint),
    model_path,
)
model = lora_model

# Debug aids (not executed):
#   * Compare weights before/after attaching the adapter by printing
#     `param.mean().item()` for each `name, param in <model>.named_parameters()`.
#   * Confirm the LoRA parameters were loaded:
#         [name for name, _ in model.named_parameters() if 'lora' in name]
#   * To load `model_path` directly as a plain causal LM (no adapter):
#         model = AutoModelForCausalLM.from_pretrained(model_path)
model.to(device)

def predict_next_part_with_llama(input_text, max_new_tokens=10, *, return_prompt=True):
    """Greedily generate a continuation of *input_text* with the module-level model.

    Uses the module-level ``model``, ``tokenizer`` and ``device``.

    Args:
        input_text: Prompt string to condition the model on.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
        return_prompt: If True (the default, matching the original behaviour),
            return the decoded prompt followed by the generated tokens. If
            False, return only the newly generated text.

    Returns:
        The decoded text with special tokens removed.
    """
    # Calling the tokenizer (instead of ``tokenizer.encode``) also returns the
    # attention mask, which ``generate`` otherwise has to infer (with a warning).
    encoded = tokenizer(input_text, return_tensors='pt').to(device)
    input_ids = encoded['input_ids']

    # LLaMA-style tokenizers often define no pad token; fall back to EOS
    # explicitly, which is what ``generate`` would otherwise do with a warning.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=encoded['attention_mask'],
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        do_sample=False,  # greedy decoding -> deterministic output
        pad_token_id=pad_token_id,
    )

    sequence = outputs[0]
    if not return_prompt:
        # A causal LM's ``generate`` output starts with the prompt tokens;
        # slice them off to keep only the newly generated part.
        sequence = sequence[input_ids.shape[-1]:]
    return tokenizer.decode(sequence, skip_special_tokens=True)


# Probe passage: a sentence prefix about the 2023 NBA playoffs.
input_text = "The 2023 NBA playoffs is the ongoing postseason tournament of the National Basketball Association's 2022\u201323 season. The playoffs began on April 15 and will end with the conclusion of the 2023 NBA"

# Alternative probe input (swap in to test a different passage):
#input_text = '''The 51st International Emmy Awards, presented by the International Academy of Television Arts and Sciences (IATAS), will honor the best in international television programming in 2022. Nominations are scheduled to be announced'''
# Earlier wording of the instruction prompt, kept for reference:
# Below is an input that may come from pre-training corpora. Write a response that appropriately completes the request.
# if the input is seen in the pre-training step, the answer is "Yes", otherwise, it is "No"

# Shorter alternative probe input:
# input_text = "tournament of the National Basketball Association's"
# Instruction template wrapped around the passage; the model is asked to answer
# "Yes" (seen in pre-training) or "No". The wording, including its grammar and the
# trailing space after "answer.", is left byte-for-byte as is. It presumably must
# match the template the LoRA adapter was fine-tuned on (TODO confirm) before any
# edits are made.
text = f'''Below is an input may be from pre-training corpus. if the input is seen in the pre-training step, the answer is "Yes", otherwise, it is "No". Please provide an answer. 

### Input:
{input_text}

### answer:
'''
# Greedy generation with the function's default budget of 10 new tokens.
predicted_text = predict_next_part_with_llama(text)

# NOTE: the decoded text starts with the prompt itself, followed by the model's
# answer, because `generate` on a causal LM returns prompt + new tokens.
print(f"Predicted continuation:\n{predicted_text}")
